/*!
 * Copyright (c) Alibaba, Inc. and its affiliates.
 * @file    binary.cu
 */

#include "allspark.pb.h"
#include "cuda_common.h"  // NOLINT
#include "cuda_kernel.h"
#include "elementwise.cuh"

namespace allspark {
namespace cuda {

// Element-wise addition functor: combines two operands of type T into x + y.
// Callable from both host and device code.
template <typename T>
struct AddFunctor {
  __host__ __device__ __forceinline__ T operator()(T x, T y) const {
    const T sum = x + y;
    return sum;
  }
};
// Element-wise multiplication functor: combines two operands of type T into
// x * y. Callable from both host and device code.
template <typename T>
struct MulFunctor {
  __host__ __device__ __forceinline__ T operator()(T x, T y) const {
    const T product = x * y;
    return product;
  }
};
// Element-wise functor that evaluates x * y + 1, where the constant 1 is
// first converted to T. Callable from both host and device code.
template <typename T>
struct FusedMulAdd1Functor {
  __host__ __device__ __forceinline__ T operator()(T x, T y) const {
    const T one = static_cast<T>(1);
    return x * y + one;
  }
};
// Gated GELU functor: returns x * GeLU(y).
//
// Note that the activation is applied to the SECOND operand (y) and the
// result is multiplied by the first operand (x). GeLU is evaluated with the
// tanh approximation
//   GeLU(y) ~= 0.5 * y * (1 + tanh(sqrt(2/pi) * (y + 0.044715 * y^3)))
// where 0.7978845608028654 == sqrt(2/pi).
template <typename T>
struct GeGLUFunctor {
  __device__ __host__ __forceinline__ T operator()(T x, T y) const {
    // Exact erf-based form, slower (kept for reference):
    // return x * T(float(y) * 0.5f * (1.0f + erff(float(y) * 0.70710678f)));

    // The activation is evaluated entirely in float and rounded to T once.
    // Doing the polynomial in T loses precision for 16-bit float types and,
    // for integral T, truncates both constants to 0 so the result would
    // always be 0. For T == float the arithmetic is unchanged.
    const float yf = static_cast<float>(y);
    const float cdf =
        0.5f * (1.0f + tanhf(0.7978845608028654f *
                             (yf + 0.044715f * yf * yf * yf)));
    return x * static_cast<T>(yf * cdf);
  }
};
// Gated SiLU (swish) functor: returns x * (y * sigmoid(y)), with
// sigmoid(y) = 1 / (1 + exp(-y)). The activation is applied to the second
// operand and the result is multiplied by the first.
template <typename T>
struct SwiGLUFunctor {
  __host__ __device__ __forceinline__ T operator()(T x, T y) const {
    // exp(-y); the negated operand is converted back to T before the call.
    const auto exp_neg_y = expf(static_cast<T>(-y));
    const auto sigmoid_y = 1.0f / (1.0f + exp_neg_y);
    // Store the activated value back into the by-value parameter.
    y = y * sigmoid_y;
    return x * y;
  }
};

// Dispatches an element-wise binary operation over `count` elements.
//
// The functor matching `type` (a BinaryType value) is handed to
// elementwise::Binary together with the output pointer, the two input
// pointers and the stream. A `type` that matches none of the handled values
// results in no call at all; the function returns silently.
template <typename T>
void BinaryKernelLauncher(T* out, const T* in1, const T* in2, int64_t count,
                          int type, cudaStream_t stream) {
  if (type == BinaryType::ADD) {
    elementwise::Binary(AddFunctor<T>(), count, out, in1, in2, stream);
  } else if (type == BinaryType::MUL) {
    elementwise::Binary(MulFunctor<T>(), count, out, in1, in2, stream);
  } else if (type == BinaryType::FUSED_MUL_ADD_1) {
    elementwise::Binary(FusedMulAdd1Functor<T>(), count, out, in1, in2,
                        stream);
  } else if (type == BinaryType::GEGLU) {
    elementwise::Binary(GeGLUFunctor<T>(), count, out, in1, in2, stream);
  } else if (type == BinaryType::SWIGLU) {
    elementwise::Binary(SwiGLUFunctor<T>(), count, out, in1, in2, stream);
  }
  // Any other value of `type`: nothing is launched.
}

// Explicit instantiations of BinaryKernelLauncher for the element types
// provided by this translation unit: float, int, half (only when ENABLE_FP16
// is defined) and hie::bfloat16.
template void BinaryKernelLauncher<float>(float* out, const float* in1,
                                          const float* in2, int64_t count,
                                          int type, cudaStream_t stream);
template void BinaryKernelLauncher<int>(int* out, const int* in1,
                                        const int* in2, int64_t count, int type,
                                        cudaStream_t stream);
#ifdef ENABLE_FP16
// Half-precision variant, compiled only when FP16 support is enabled.
template void BinaryKernelLauncher<half>(half* out, const half* in1,
                                         const half* in2, int64_t count,
                                         int type, cudaStream_t stream);
#endif
template void BinaryKernelLauncher<hie::bfloat16>(hie::bfloat16* out,
                                                  const hie::bfloat16* in1,
                                                  const hie::bfloat16* in2,
                                                  int64_t count, int type,
                                                  cudaStream_t stream);
}  // namespace cuda
}  // namespace allspark
